{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321dd425",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import shutil\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import onekey_algo.custom.components as okcomp\n",
    "from onekey_algo import get_param_in_cwd\n",
    "\n",
    "# Directory that receives the figures saved by the cells below.\n",
    "os.makedirs('img', exist_ok=True)\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "\n",
    "# Settings looked up through get_param_in_cwd.\n",
    "model_names = get_param_in_cwd('compare_model')\n",
    "task = get_param_in_cwd('task_column') or 'label'  # 'label' is used when the looked-up value is falsy\n",
    "labelf = get_param_in_cwd('label_file')\n",
    "results_dir = get_param_in_cwd('results_dir')\n",
    "group_info = get_param_in_cwd('dataset_column')\n",
    "\n",
    "# Read the label file and keep only the ID column plus the task column.\n",
    "labels = [task]\n",
    "label_data_ = pd.read_csv(labelf)\n",
    "ids = label_data_['ID']\n",
    "print(label_data_.columns)\n",
    "label_data = label_data_[['ID', *labels]]\n",
    "label_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac5ee2cf",
   "metadata": {},
   "source": [
    "# 训练集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5002786",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Collect every compared model's scores for this cohort into one frame keyed by ID.\n",
    "subset = 'train'\n",
    "ALL_results = None\n",
    "for mn in model_names:\n",
    "    r = pd.read_csv(os.path.join(results_dir, f'{mn}_{subset}.csv'))\n",
    "    # The assignment below requires each result CSV to have exactly three columns.\n",
    "    # NOTE(review): the second column is presumably the class-0 score and the third the class-1 score - confirm.\n",
    "    r.columns = ['ID', '-0', mn]\n",
    "    # Keep only the ID and this model's score column. Carrying the shared '-0' column into\n",
    "    # every merge produces '-0_x'/'-0_y' suffixes that collide once four or more models are\n",
    "    # compared, and the column is never read afterwards.\n",
    "    r = r[['ID', mn]]\n",
    "    ALL_results = r if ALL_results is None else pd.merge(ALL_results, r, on='ID', how='inner')\n",
    "\n",
    "# Attach the task label; the inner joins keep only IDs present in every file.\n",
    "ALL_results = pd.merge(ALL_results, label_data, on='ID', how='inner')\n",
    "\n",
    "# NOTE(review): dropna(axis=1) removes whole columns that contain any NaN, so a model with a\n",
    "# missing prediction loses its column - confirm this is intended rather than dropping rows.\n",
    "ALL_results = ALL_results.dropna(axis=1)\n",
    "ALL_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e04731e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "pred_column = [f'{task}-0', f'{task}-1']\n",
    "# Build one ground-truth array per compared model, paired with that model's score column.\n",
    "gt = []\n",
    "pred_train = []\n",
    "for name in model_names:\n",
    "    gt.append(np.array(ALL_results[task]))\n",
    "    pred_train.append(np.array(ALL_results[name]))\n",
    "okcomp.comp1.draw_roc(gt, pred_train, labels=model_names, title=f\"Model AUC\")\n",
    "plt.savefig(f'img/compare_{subset}_auc.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4826febd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.metrics import analysis_pred_binary\n",
    "metric_columns = ['Signature', 'Accuracy', 'AUC', '95% CI', 'Sensitivity', 'Specificity',\n",
    "                  'PPV', 'NPV', 'Precision', 'Recall', 'F1', 'Threshold', 'Cohort']\n",
    "metric = []\n",
    "for mname, y, score in zip(model_names, gt, pred_train):\n",
    "    # Metrics of this model on the cohort loaded above (rows are tagged \"Train\").\n",
    "    acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres = analysis_pred_binary(y, score)\n",
    "    ci_low, ci_high = ci[0], ci[1]\n",
    "    row = (mname, acc, auc, f\"{ci_low:.4f} - {ci_high:.4f}\", tpr, tnr, ppv, npv,\n",
    "           precision, recall, f1, thres, \"Train\")\n",
    "    metric.append(row)\n",
    "# `metric` stays a plain list so that the test-cohort cell can extend it.\n",
    "pd.DataFrame(metric, index=None, columns=metric_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69899f07",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.delong import delong_roc_test\n",
    "from onekey_algo.custom.components.comp1 import draw_matrix\n",
    "\n",
    "delong = []\n",
    "delong_columns = []\n",
    "this_delong = []\n",
    "plt.figure(figsize=(5, 4))\n",
    "# Start from an all-NaN square matrix and fill only the strict lower triangle (row > column)\n",
    "# with the [0][0] element of delong_roc_test's result for each pair of models.\n",
    "# NOTE(review): that element is presumably the p-value of the test - confirm.\n",
    "n_models = len(model_names)\n",
    "cm = np.full((n_models, n_models), np.nan)\n",
    "for i in range(n_models):\n",
    "    for j in range(i):\n",
    "        mni, mnj = model_names[i], model_names[j]\n",
    "        cm[i, j] = delong_roc_test(ALL_results[task], ALL_results[mni], ALL_results[mnj])[0][0]\n",
    "# The first row and the last column contain only NaN, so they are cut away before plotting.\n",
    "cm = pd.DataFrame(cm[1:, :-1], index=model_names[1:], columns=model_names[:-1])\n",
    "draw_matrix(cm, annot=True, cmap='jet_r', cbar=True)\n",
    "plt.title(f'Cohort {subset} Delong')\n",
    "plt.savefig(f'img/compare_delong_each_cohort_{subset}.svg', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4045192a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import plot_DCA\n",
    "# Draw the DCA plot: one score series per compared model against the shared task labels.\n",
    "plot_DCA(\n",
    "    [ALL_results[name] for name in model_names],\n",
    "    ALL_results[task],\n",
    "    title=f'Model for DCA',\n",
    "    labels=model_names,\n",
    "    y_min=-0.15,\n",
    ")\n",
    "plt.savefig(f'img/compare_{subset}_dca.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40c40079",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import draw_calibration\n",
    "# Calibration plot for every model, using the gt / pred_train lists built in the ROC cell.\n",
    "draw_calibration(\n",
    "    y_test=gt,\n",
    "    pred_scores=pred_train,\n",
    "    model_names=model_names,\n",
    "    n_bins=5,\n",
    ")\n",
    "plt.savefig(f'img/compare_{subset}_cali.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ec7e842",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components import stats\n",
    "\n",
    "# First row of `hosmer`: one hosmer_lemeshow_test result per model for this cohort.\n",
    "# NOTE(review): bins=15 here, while the test-cohort cell passes bins=5 - confirm the difference is intended.\n",
    "hosmer = [\n",
    "    [stats.hosmer_lemeshow_test(y_true, y_pred, bins=15) for y_true, y_pred in zip(gt, pred_train)]\n",
    "]\n",
    "pd.DataFrame(hosmer, columns=model_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdd19d5c",
   "metadata": {},
   "source": [
    "# 测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5e08e9d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Collect every compared model's scores for this cohort into one frame keyed by ID.\n",
    "subset = 'test'\n",
    "ALL_results = None\n",
    "for mn in model_names:\n",
    "    r = pd.read_csv(os.path.join(results_dir, f'{mn}_{subset}.csv'))\n",
    "    # The assignment below requires each result CSV to have exactly three columns.\n",
    "    # NOTE(review): the second column is presumably the class-0 score and the third the class-1 score - confirm.\n",
    "    r.columns = ['ID', '-0', mn]\n",
    "    # Keep only the ID and this model's score column. Carrying the shared '-0' column into\n",
    "    # every merge produces '-0_x'/'-0_y' suffixes that collide once four or more models are\n",
    "    # compared, and the column is never read afterwards.\n",
    "    r = r[['ID', mn]]\n",
    "    ALL_results = r if ALL_results is None else pd.merge(ALL_results, r, on='ID', how='inner')\n",
    "\n",
    "# Attach the task label; the inner joins keep only IDs present in every file.\n",
    "ALL_results = pd.merge(ALL_results, label_data, on='ID', how='inner')\n",
    "\n",
    "# NOTE(review): dropna(axis=1) removes whole columns that contain any NaN, so a model with a\n",
    "# missing prediction loses its column - confirm this is intended rather than dropping rows.\n",
    "ALL_results = ALL_results.dropna(axis=1)\n",
    "ALL_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a67328a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_column = [f'{task}-0', f'{task}-1']\n",
    "# Build one ground-truth array per compared model, paired with that model's score column.\n",
    "# `pred_train` keeps its name because later cells read it, but here it holds the scores of\n",
    "# the cohort currently loaded in ALL_results.\n",
    "gt = []\n",
    "pred_train = []\n",
    "for name in model_names:\n",
    "    gt.append(np.array(ALL_results[task]))\n",
    "    pred_train.append(np.array(ALL_results[name]))\n",
    "okcomp.comp1.draw_roc(gt, pred_train, labels=model_names, title=f\"Model AUC\")\n",
    "plt.savefig(f'img/compare_{subset}_auc.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc97fa50",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.metrics import analysis_pred_binary\n",
    "metric_columns = ['Signature', 'Accuracy', 'AUC', '95% CI', 'Sensitivity', 'Specificity',\n",
    "                  'PPV', 'NPV', 'Precision', 'Recall', 'F1', 'Threshold', 'Cohort']\n",
    "# `metric` is the list of train rows on the first run of this cell, but it is the combined\n",
    "# DataFrame (already holding Test rows) when the cell is re-run. Normalise both cases to a\n",
    "# frame of non-Test rows so a re-run neither fails nor duplicates the Test rows.\n",
    "metric_train = pd.DataFrame(metric, columns=metric_columns)\n",
    "metric_train = metric_train[metric_train['Cohort'] != 'Test']\n",
    "\n",
    "metric_test = []\n",
    "for mname, y, score in zip(model_names, gt, pred_train):\n",
    "    # Metrics of this model on the test cohort.\n",
    "    acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres = analysis_pred_binary(y, score)\n",
    "    ci = f\"{ci[0]:.4f} - {ci[1]:.4f}\"\n",
    "    metric_test.append((mname, acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres, \"Test\"))\n",
    "\n",
    "# As before, `metric` ends up as the combined train + test DataFrame.\n",
    "metric = pd.concat([metric_train, pd.DataFrame(metric_test, columns=metric_columns)],\n",
    "                   ignore_index=True)\n",
    "\n",
    "metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff64306d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.delong import delong_roc_test\n",
    "from onekey_algo.custom.components.comp1 import draw_matrix\n",
    "\n",
    "delong = []\n",
    "delong_columns = []\n",
    "this_delong = []\n",
    "plt.figure(figsize=(5, 4))\n",
    "# Start from an all-NaN square matrix and fill only the strict lower triangle (row > column)\n",
    "# with the [0][0] element of delong_roc_test's result for each pair of models.\n",
    "# NOTE(review): that element is presumably the p-value of the test - confirm.\n",
    "n_models = len(model_names)\n",
    "cm = np.full((n_models, n_models), np.nan)\n",
    "for i in range(n_models):\n",
    "    for j in range(i):\n",
    "        mni, mnj = model_names[i], model_names[j]\n",
    "        cm[i, j] = delong_roc_test(ALL_results[task], ALL_results[mni], ALL_results[mnj])[0][0]\n",
    "# The first row and the last column contain only NaN, so they are cut away before plotting.\n",
    "cm = pd.DataFrame(cm[1:, :-1], index=model_names[1:], columns=model_names[:-1])\n",
    "draw_matrix(cm, annot=True, cmap='jet_r', cbar=True)\n",
    "plt.title(f'Cohort {subset} Delong')\n",
    "plt.savefig(f'img/compare_delong_each_cohort_{subset}.svg', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cadf93ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import plot_DCA\n",
    "# Draw the DCA plot: one score series per compared model against the shared task labels.\n",
    "plot_DCA(\n",
    "    [ALL_results[name] for name in model_names],\n",
    "    ALL_results[task],\n",
    "    title=f'Model for DCA',\n",
    "    labels=model_names,\n",
    "    y_min=-0.15,\n",
    ")\n",
    "plt.savefig(f'img/compare_{subset}_dca.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5f2d88",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import draw_calibration\n",
    "# Calibration plot for every model, using the gt / pred_train lists built in the ROC cell.\n",
    "draw_calibration(\n",
    "    y_test=gt,\n",
    "    pred_scores=pred_train,\n",
    "    model_names=model_names,\n",
    "    n_bins=5,\n",
    ")\n",
    "plt.savefig(f'img/compare_{subset}_cali.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9aebe4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components import stats\n",
    "\n",
    "# The train-cohort cell leaves exactly one row in `hosmer`. Drop anything after it so that\n",
    "# re-running this cell replaces the test row instead of appending a duplicate each time.\n",
    "del hosmer[1:]\n",
    "# NOTE(review): bins=5 here, while the train-cohort cell passes bins=15 - confirm the difference is intended.\n",
    "hosmer.append([stats.hosmer_lemeshow_test(y_true, y_pred, bins=5)\n",
    "               for y_true, y_pred in zip(gt, pred_train)])\n",
    "# Row 0 holds the train cohort, row 1 the test cohort.\n",
    "pd.DataFrame(hosmer, columns=model_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9578979",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
